import torch
import math
import random
import numpy as np
import ShiftingWindowSetting as sw
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib import cm
from torch import nn


class DeepCCG(sw.CLLearningAlgo):
    """Continual learner that classifies embeddings by their distance to per-class memory means.

    ``self.model`` (built by ``sw.CLLearningAlgo`` without a classification head) embeds
    inputs. Each class ``c`` is represented by the mean embedding of the points stored for
    it in ``self.per_class_mem[c]``. Class scores are isotropic-Gaussian log-likelihoods
    around those means (see ``_batch_log_probs_r``). After every optimiser step the new
    batch is added to memory, and any class exceeding its share of ``mem_size`` is shrunk
    by fitting per-point selection weights (see ``calc_new_mem``).
    """

    def __init__(self, args, use_task_inc_loss=True, mem_size=1000, mem_batch_size=10, min_num_mem_points=1,
                 num_of_iter_new_mem=5, mem_lr=0.1, fit_mem_batch_size=64, diag_approx=True,
                 new_mem_diag_approx=False, s_coeff=1, fixed_var=True, calc_stat_batch_size=1024,
                 mem_select_KLD=True, store_old_means=False):
        """Store hyper-parameters and create an empty memory buffer.

        Args:
            args: forwarded unchanged to ``sw.CLLearningAlgo``.
            use_task_inc_loss: if True, ``predict`` only scores the current task's classes;
                otherwise it scores every class in memory.
            mem_size: total number of points the memory may hold across all classes.
            mem_batch_size: number of memory points replayed in each ``loss_fn`` call.
            min_num_mem_points: points per class that are never taken for replay. A
                current-task class also needs at least this many points in memory to
                enter the batch loss.
            num_of_iter_new_mem: SGD steps used to fit the memory-selection weights.
            mem_lr: learning rate of that SGD.
            s_coeff: variance scale used in the class log-likelihoods and in ``S_init``.
            fixed_var: when False, ``mem_select_KLD`` is forced off and ``_mem_select_loss``
                normalises by the point count instead of the weight sum.
            calc_stat_batch_size: batch size used when embedding the whole memory.
            store_old_means: if True, ``loss_fn`` initialises per-class old-data statistics.

        The remaining arguments are stored as attributes.
        """
        # classification_head=False: the embedding network is created without a classification head
        super().__init__(args=args, classification_head=False)

        # Memory buffer: class label -> list of (input, task_id) pairs, inputs kept on the CPU.
        # Created per instance: a class-level dict would be shared (and appended to) by every
        # DeepCCG object in the process.
        self.per_class_mem = {}

        self.mem_size = mem_size
        self.mem_batch_size = mem_batch_size
        self.seen_classes = []
        self.min_num_mem_points = min_num_mem_points
        self.use_task_inc_loss = use_task_inc_loss
        self.fixed_var = fixed_var
        self.mem_select_KLD = mem_select_KLD if fixed_var else False
        self.calc_stat_batch_size = calc_stat_batch_size

        # NOTE: the representation size is derived from the model width (a temporary setting)
        # rather than read from the embedding network's final linear layer.
        self.rep_size = self.model_width * 8

        # memory-buffer statistics computed in eval() and used for prediction
        self.mem_means = None
        self.mem_precisions = None
        self.mem_log_sqrt_dets = None

        # settings for selecting which points stay in memory
        self.num_of_iter_new_mem = num_of_iter_new_mem  # 0 iterations -> selection by the random initial weights
        self.mem_lr = mem_lr
        self.fit_mem_batch_size = fit_mem_batch_size
        # class label -> bool, read in calc_new_mem to allow truncating a class instead of refitting it.
        # NOTE(review): nothing in this class ever sets an entry to True.
        self.class_sorted = {}

        # whether to use a diagonal covariance approximation, for the method as a whole and
        # (not used at the moment) for the memory-selection step
        self.diag_approx = diag_approx
        self.new_mem_diag_approx = new_mem_diag_approx
        self.det_bias = 0.0  # intended to keep determinants from underflowing

        # prior over the model parameters
        self.s_coeff = s_coeff
        if self.diag_approx or self.fixed_var:
            self.S_init = s_coeff * torch.ones(self.rep_size)
        else:
            self.S_init = s_coeff * torch.eye(self.rep_size)
        self.S_init = self.S_init.to(self.device)

        # When store_old_means is True these hold, per class, statistics of data removed from
        # memory (to be used as a prior on the mean); they are initialised in loss_fn.
        self.store_old_means = store_old_means
        self.old_data_sum = None
        self.number_of_old_data_points = None
        self.old_data_means = None
        self.old_data_var_reg = None

    def _calc_means(self, memory, classes):
        """Return {class: mean embedding of memory[class]} for every class in ``classes``.

        ``memory`` maps class -> list of entries whose first element is an input. Inputs are
        embedded in batches of ``calc_stat_batch_size``. Returns {} when ``classes`` is empty.
        Gradients are not stopped here; callers that only need values should wrap the call
        in ``torch.no_grad()``.
        """
        if len(classes) == 0:
            return {}

        mem_means = {c: torch.zeros(self.rep_size, device=self.device) for c in classes}
        mem_dataloader = DataLoader([(data[0], c) for c in classes for data in memory[c]],
                                    shuffle=False, batch_size=self.calc_stat_batch_size)
        for X, Y in mem_dataloader:
            r = self.model(X.to(self.device))
            for c in classes:
                mem_means[c] += r[Y == c].sum(0)
        for c in classes:
            mem_means[c] /= len(memory[c])
        return mem_means

    def _calc_means_r(self, r, memory, classes):
        """Return {class: mean row of ``r``} from already-computed embeddings.

        ``r`` must hold the embeddings of ``memory[c]`` for each ``c`` in ``classes``, stacked
        consecutively in that order (as built in ``loss_fn``). Gradients flow through the means.
        """
        # start from float32 zeros of size rep_size so the means keep that dtype/shape
        mem_means = {c: torch.zeros(self.rep_size, device=self.device) for c in classes}
        start = 0
        for c in classes:
            stop = start + len(memory[c])
            mem_means[c] += r[start:stop].mean(0)
            start = stop
        return mem_means

    def _calc_mem_sum(self, memory, classes):
        """Return {class: sum of the embeddings of memory[class]}; {} when ``classes`` is empty."""
        if len(classes) == 0:
            return {}

        mem_sum = {c: torch.zeros(self.rep_size, device=self.device) for c in classes}
        mem_dataloader = DataLoader([(data[0], c) for c in classes for data in memory[c]],
                                    shuffle=True, batch_size=self.calc_stat_batch_size)
        for X, Y in mem_dataloader:
            r = self.model(X.to(self.device))
            for c in classes:
                mem_sum[c] += r[Y == c].sum(0)
        return mem_sum

    def _batch_log_probs(self, X, means, classes, memory):
        """Embed ``X`` and return its (rows x classes) log-likelihoods (see ``_batch_log_probs_r``)."""
        r = self.model(X.to(self.device))
        return self._batch_log_probs_r(r, means, classes, memory)

    def _batch_log_probs_r(self, r, means, classes, memory):
        """Return a (len(r), len(classes)) tensor of per-class log-likelihoods.

        Column j scores every row of ``r`` under an isotropic Gaussian centred on
        ``means[classes[j]]`` with variance ``s_coeff * (1 + 1/n)``, where
        ``n = len(memory[classes[j]])``. Terms that are identical for every class are dropped.
        """
        dim = r.shape[1]
        log_probs = []
        for c in classes:
            var_scale = 1 / len(memory[c]) + 1
            k = -1 / (2 * self.s_coeff * var_scale)
            log_probs.append(((r - means[c]) ** 2).sum(1) * k - dim / 2 * math.log(var_scale))
        return torch.stack(log_probs, dim=1)

    def _current_batch_loss(self, X, Y, means, classes, memory):
        """Embed ``X`` and return its batch loss (see ``_current_batch_loss_r``)."""
        if len(classes) == 0:
            # return before the forward pass so the model is not run needlessly
            return torch.zeros(1, device=self.device)
        r = self.model(X.to(self.device))
        return self._current_batch_loss_r(r, Y, means, classes, memory)

    def _current_batch_loss_r(self, r, Y, means, classes, memory):
        """Mean negative log-softmax of the true class over the batch.

        Only examples whose label is in ``classes`` add to the sum, but the sum is divided
        by the full batch size ``Y.shape[0]``. Returns a zero tensor (without grad) when
        ``classes`` is empty.
        """
        loss = torch.zeros(1, device=self.device)
        if len(classes) == 0:
            return loss
        data_log_probs = self._batch_log_probs_r(r, means, classes, memory)
        z = torch.logsumexp(data_log_probs, 1)
        # one host transfer for all labels instead of one .item() per example
        for i, label in enumerate(Y.tolist()):
            if label in classes:
                j = classes.index(label)
                loss -= data_log_probs[i, j] - z[i]
        return 1 / Y.shape[0] * loss

    def _calc_random_replay_sample(self, per_task_replay_data, memory, num_of_samples):
        """Move up to ``num_of_samples`` random points out of ``memory`` into ``per_task_replay_data``.

        The first ``min_num_mem_points`` entries of each class are never chosen, so every class
        keeps at least that many points in ``memory``. A chosen entry ``(x, task_id)`` of class
        ``y`` is appended to ``per_task_replay_data[task_id]`` as ``(x, y)`` and deleted from
        ``memory[y]``. Both dicts are modified in place.
        """
        if num_of_samples == 0:
            return

        # number of samplable entries per class (classes with none are skipped)
        lengths = {}
        for y in memory:
            available = len(memory[y]) - self.min_num_mem_points
            if available > 0:
                lengths[y] = available
        n = sum(lengths.values())

        # draw sorted positions in [0, n), i.e. over all samplable entries laid end to end
        if n <= num_of_samples:
            indexes = list(range(n))
        else:
            indexes = sorted(random.sample(range(n), k=num_of_samples))

        # map the global positions back to per-class memory indices
        class_grouped_indexes = {}
        pos = 0
        offset = 0
        for y, count in lengths.items():
            group = []
            while pos < len(indexes) and indexes[pos] < offset + count:
                group.append(indexes[pos] - offset + self.min_num_mem_points)
                pos += 1
            class_grouped_indexes[y] = group
            offset += count

        # delete from the back so the remaining indices of the class stay valid
        for y, group in class_grouped_indexes.items():
            for index in reversed(group):
                entry = memory[y][index]
                per_task_replay_data[entry[1]].append((entry[0], y))
                del memory[y][index]

    def _split_memory_for_per_task_replay(self):
        """Sample a replay batch from a copy of the memory buffer.

        Returns:
            per_task_replay_data: task id -> list of (x, class) replay pairs; tasks with no
                sampled points are omitted.
            memory: class -> list of (x, task_id); a copy of ``self.per_class_mem`` with the
                replayed points removed.
            memory_classes: task id -> that task's classes (from ``calc_classes``), for each
                task in ``per_task_replay_data``.
        """
        memory = {y: list(self.per_class_mem[y]) for y in self.per_class_mem}
        per_task_replay_data = {t: [] for t in range(self.task_id + 1)}
        self._calc_random_replay_sample(per_task_replay_data, memory, self.mem_batch_size)

        per_task_replay_data = {t: data for t, data in per_task_replay_data.items() if data}
        memory_classes = {t: self.calc_classes(t, self.task_stream.classes) for t in per_task_replay_data}
        return per_task_replay_data, memory, memory_classes

    def _calc_task_inc_replay_loss(self, per_task_replay_data, means, memory_classes, memory):
        """Embed the replay data and return its replay loss (see ``_calc_task_inc_replay_loss_r``).

        Each ``per_task_replay_data[t]`` list of (x, y) pairs is converted in place into an
        ``(X, Y)`` tensor pair. Returns a zero tensor if there is no replay data.
        """
        if not per_task_replay_data:
            return torch.zeros(1, device=self.device)
        for t in per_task_replay_data:
            X, Y = tuple(zip(*per_task_replay_data[t]))
            per_task_replay_data[t] = torch.stack(X), torch.tensor(Y)
        X = torch.cat([per_task_replay_data[t][0] for t in per_task_replay_data])
        r = self.model(X.to(self.device))
        return self._calc_task_inc_replay_loss_r(r, per_task_replay_data, means, memory_classes, memory)

    def _calc_task_inc_replay_loss_r(self, r, per_task_replay_data, means, memory_classes, memory):
        """Mean negative log-softmax over all replayed points, each scored within its own task.

        ``per_task_replay_data[t]`` is an ``(X, Y)`` tensor pair, and ``r`` holds the replay
        embeddings stacked in the iteration order of ``per_task_replay_data``. A point of task
        ``t`` is scored against ``memory_classes[t]``; ``list.index`` raises ValueError if its
        label is not among them. Returns a zero tensor when there are no points.
        """
        replay_loss = torch.zeros(1, device=self.device)
        n = 0
        for t in per_task_replay_data:
            Y = per_task_replay_data[t][1]
            num_points = Y.shape[0]
            log_probs = self._batch_log_probs_r(r[n:n + num_points], means, memory_classes[t], memory)
            z = torch.logsumexp(log_probs, 1)
            for i, label in enumerate(Y.tolist()):
                j = memory_classes[t].index(label)
                replay_loss -= log_probs[i, j] - z[i]
            n += num_points
        return 1 / n * replay_loss if n != 0 else replay_loss

    def loss_fn(self, X, Y):
        """Loss on the batch (X, Y) plus a replay loss on points sampled from memory.

        In the task-incremental setting ``self.per_class_mem`` stores ``(x, task_id)`` pairs.
        Class means come from the memory left after the replay sample is removed. They are
        embedded in the same forward pass as the replay points and the batch. The batch loss
        only scores current-task classes with at least ``min_num_mem_points`` points in memory.
        Each replayed point is scored against its own task's classes. Returns a zero tensor
        that requires grad when there is nothing to score yet.
        """
        # with an empty memory no class means exist, so no loss can be computed
        if len(self.per_class_mem) == 0:
            # at the start of the experiment, initialise the old-data statistics (used as a mean prior)
            if self.store_old_means:
                self.old_data_means = {y: torch.zeros(self.rep_size, device=self.device) for y in self.task_stream.classes}
                self.old_data_sum = {y: torch.zeros(self.rep_size, device=self.device) for y in self.task_stream.classes}
                self.number_of_old_data_points = {y: 0 for y in self.task_stream.classes}
                self.old_data_var_reg = {y: 1.0 for y in self.task_stream.classes}
            return torch.zeros(1, device=self.device, requires_grad=True)

        per_task_replay_data, memory, memory_classes = self._split_memory_for_per_task_replay()

        # current-task classes that already have enough points in memory
        current_task_classes = [
            c for c in self.calc_classes(self.task_id, self.task_stream.classes)
            if c in memory and len(memory[c]) >= self.min_num_mem_points
        ]

        # every class whose mean is needed: the replayed tasks' classes plus the current task's
        classes = set()
        for t in memory_classes:
            if t == self.task_id:
                memory_classes[t] = current_task_classes
            classes.update(memory_classes[t])
        classes.update(current_task_classes)
        classes = list(classes)

        if len(classes) == 0:
            return torch.zeros(1, device=self.device, requires_grad=True)

        for t in per_task_replay_data:
            X_replay, Y_replay = tuple(zip(*per_task_replay_data[t]))
            per_task_replay_data[t] = torch.stack(X_replay), torch.tensor(Y_replay)

        # One forward pass over [memory, replay, batch]. The replay part may be empty
        # (mem_batch_size == 0, or no class holds more than min_num_mem_points points), and
        # torch.cat rejects an empty list, so it is only added when present.
        X_mem = torch.stack([data[0] for c in classes for data in memory[c]]).to(self.device)
        n_mem = X_mem.shape[0]
        parts = [X_mem]
        n_replay = 0
        if per_task_replay_data:
            X_replay = torch.cat([per_task_replay_data[t][0] for t in per_task_replay_data]).to(self.device)
            n_replay = X_replay.shape[0]
            parts.append(X_replay)
        parts.append(X.to(self.device))

        r = self.model(torch.cat(parts))
        r_mem = r[:n_mem]
        r_replay = r[n_mem:n_mem + n_replay]
        r_batch = r[n_mem + n_replay:]

        means = self._calc_means_r(r_mem, memory, classes)

        if self.mem_batch_size == 0 or n_replay == 0:
            # nothing replayed: a zero that requires grad, like the early-return paths above
            replay_loss = torch.zeros(1, device=self.device, requires_grad=True)
        else:
            replay_loss = self._calc_task_inc_replay_loss_r(r_replay, per_task_replay_data, means,
                                                           memory_classes, memory)

        batch_loss = self._current_batch_loss_r(r_batch, Y, means, current_task_classes, memory)
        return batch_loss + replay_loss

    def eval(self):
        """Switch to evaluation mode and cache the per-class memory means used for prediction.

        Returns whatever ``sw.CLLearningAlgo.eval`` returns.
        """
        result = super().eval()
        # the cached means are constants at test time, so no autograd graph is needed
        with torch.no_grad():
            mem_means = self._calc_means(self.per_class_mem, self.per_class_mem.keys())
        for y in self.per_class_mem:
            mem_means[y] = mem_means[y].detach()
        self.mem_means = mem_means
        return result

    def _calc_predictions(self, X, classes):
        """Return, per row of ``X``, the class in ``classes`` with the highest log-likelihood.

        Uses the means cached by ``eval()``.
        """
        log_probs = self._batch_log_probs(X, self.mem_means, classes, self.per_class_mem)
        return torch.tensor(classes, device=self.device)[log_probs.argmax(1)]

    def calc_class_inc_model_predictions(self, X):
        """Predict among every class currently in memory (class-incremental evaluation)."""
        return self._calc_predictions(X, list(self.per_class_mem.keys()))

    def calc_task_inc_model_predictions(self, X):
        """Predict among the classes of the current task (task-incremental evaluation)."""
        return self._calc_predictions(X, self.calc_classes(self.task_id, self.task_stream.classes))

    def predict(self, X):
        """Return ``X`` unchanged while training; otherwise the predicted class labels."""
        if self.training:
            return X
        if self.use_task_inc_loss:
            return self.calc_task_inc_model_predictions(X)
        return self.calc_class_inc_model_predictions(X)

    def _get_mem_representations(self, class_memory):
        """Embed every entry of ``class_memory`` and return the embeddings as a list of 1-D tensors.

        ``class_memory`` is a list of pairs whose first element is an input (the second is
        collated but ignored). Embeddings are computed without gradient tracking.
        """
        r_class_memory = []
        mem_dataloader = DataLoader(class_memory, batch_size=self.calc_stat_batch_size)
        with torch.no_grad():
            for X, _ in mem_dataloader:
                r = self.model(X.to(self.device)).detach()
                r.requires_grad = False
                r_class_memory.extend(r[i] for i in range(X.shape[0]))
        return r_class_memory

    def _mem_select_loss(self, r, mem_selector, mem_mean, new_mem_size):
        """Loss for fitting one class's memory-selection weights.

        The first term is the squared distance between ``mem_mean`` and the selector-weighted
        mean of the rows of ``r``. The second is an L1 penalty on the squashed weights, scaled
        by ``n / new_mem_size`` (``n`` = number of rows). The weighted mean is normalised by
        the weight sum when ``fixed_var`` is True and by ``n`` otherwise.
        """
        n = r.shape[0]
        weights = mem_selector.calc_weights(0, n)
        mod_n = weights.sum() if self.fixed_var else n
        weighted_mean = mem_selector(r, 0, n).sum(0) / mod_n
        # open question: is n / new_mem_size a good regularisation coefficient?
        return ((mem_mean - weighted_mean) ** 2).sum() + n / new_mem_size * weights.abs().sum()

    def _mem_select_loss_all(self, r, mem_selector, mem_means, new_mem_sizes, splits):
        """Joint version of ``_mem_select_loss`` for several classes fitted with one selector.

        ``r`` stacks each class's embeddings consecutively in the order of ``splits`` (class ->
        number of rows). ``mem_means`` is the 1-D concatenation of the target class means in that
        same order. ``new_mem_sizes`` maps class -> memory budget.
        """
        n = r.shape[0]
        weighted_r = mem_selector(r, 0, n)

        weighted_class_means = []
        l1_terms = []
        start = 0
        for y in splits:
            stop = start + splits[y]
            weights = mem_selector.calc_weights(start, stop)
            weighted_class_means.append(weighted_r[start:stop].sum(0) / weights.sum())
            # NOTE(review): the coefficient uses the total row count n, whereas
            # _mem_select_loss uses the class's own row count.
            l1_terms.append(n / new_mem_sizes[y] * weights.abs().sum())
            start = stop

        weighted_mean = torch.cat(weighted_class_means)
        # Sum the squared error over the whole concatenated mean vector: its positions are
        # feature indices, not row offsets, so it must not be sliced with the row splits.
        loss = torch.zeros(1, device=self.device, requires_grad=True)
        loss = loss + ((mem_means - weighted_mean) ** 2).sum()
        for term in l1_terms:
            loss = loss + term
        return loss

    def _calc_per_class_mem_sizes(self):
        """Split ``mem_size`` between the classes currently in memory.

        Classes holding no more points than an equal share of the remaining memory keep all
        of them. The memory they leave unused is shared equally (integer division) among the
        larger classes, and any remainder is left unused. Returns an integer array indexed by
        class label. This assumes labels are valid indices into an array of length
        ``len(self.task_stream.classes)``.
        """
        per_class_mem_size = np.array([0] * len(self.task_stream.classes))
        classes = np.array(list(self.per_class_mem.keys()))
        ns = [len(self.per_class_mem[y]) for y in self.per_class_mem]
        order = sorted(range(len(ns)), key=lambda j: ns[j])

        remaining_mem = self.mem_size
        start = 0
        while start < len(ns):
            min_n = ns[order[start]]
            equal_share = remaining_mem // (len(ns) - start)
            if min_n > equal_share:
                # every remaining class is larger than an equal share: cap them all at it
                per_class_mem_size[classes[order[start:]]] = equal_share
                break
            # the classes tied at the smallest size keep all of their points
            end = start
            while end < len(ns) and ns[order[end]] == min_n:
                end += 1
            per_class_mem_size[classes[order[start:end]]] = min_n
            remaining_mem -= min_n * (end - start)
            start = end
        return per_class_mem_size

    def _fit_new_mem(self, fit_classes, per_class_mem_size):
        """Reduce each class in ``fit_classes`` to ``per_class_mem_size[class]`` points.

        For each class a ``MemSelectModel`` is fitted with SGD on ``_mem_select_loss``. The
        points with the largest weights are kept, in descending weight order.
        """
        # Embeddings and target means are constants while fitting the weights, so build them
        # without an autograd graph over the whole memory.
        with torch.no_grad():
            mem_means = self._calc_means(self.per_class_mem, fit_classes)
            X = torch.cat([torch.stack(get_first_elements_tuples(self.per_class_mem[y])) for y in fit_classes])
            r = self.model(X.to(self.device))

        start = 0
        for y in fit_classes:
            n = len(self.per_class_mem[y])
            class_r = r[start:start + n]
            start += n

            mem_selector = MemSelectModel(n, device=self.device)
            optimiser = torch.optim.SGD(mem_selector.parameters(), lr=self.mem_lr)
            for _ in range(self.num_of_iter_new_mem):
                optimiser.zero_grad()
                loss = self._mem_select_loss(class_r, mem_selector, mem_means[y], per_class_mem_size[y])
                loss.backward()
                optimiser.step()
            _, indices = torch.topk(mem_selector.weights, int(per_class_mem_size[y]), sorted=True)
            self.per_class_mem[y] = [self.per_class_mem[y][i] for i in indices.tolist()]

    def calc_new_mem(self, X, Y):
        """Add the batch (X, Y) to memory, then shrink every class that exceeds its budget.

        Budgets come from ``_calc_per_class_mem_sizes``, and over-budget classes are reduced
        by ``_fit_new_mem``. The embedding model is in eval mode during the update and is
        switched back to train mode afterwards, even if an error is raised.
        """
        self.model.eval()
        try:
            labels = Y.tolist()
            for x, y in zip(X, labels):
                self.per_class_mem.setdefault(y, []).append((x.to("cpu"), self.task_id))

            per_class_mem_size = self._calc_per_class_mem_sizes()
            batch_classes = set(labels)

            # classes whose memory needs the (expensive) weight-fitting selection
            fit_classes = []
            for y in self.per_class_mem:
                if len(self.per_class_mem[y]) <= per_class_mem_size[y]:
                    if y in batch_classes:
                        self.class_sorted[y] = False
                    continue
                if y not in batch_classes and self.class_sorted.get(y, False):
                    # memory flagged as already ordered: keep its first per_class_mem_size[y] points
                    self.per_class_mem[y] = self.per_class_mem[y][:per_class_mem_size[y]]
                    continue
                fit_classes.append(y)

            if fit_classes:
                self._fit_new_mem(fit_classes, per_class_mem_size)
        finally:
            self.model.train()

    def after_optimiser_step(self):
        """Update the memory with ``self.batch`` (presumably the (X, Y) batch just trained on)."""
        self.calc_new_mem(*self.batch)


# this class is a wrapper to weighting our representations such that we can use a torch optimiser to fit the weights
class MemSelectModel(torch.nn.Module):
    """Learnable per-point weights used to choose which memory points to keep.

    Every data point owns one scalar parameter. ``calc_weights`` squashes a slice of them
    through ``sigmoid(2 * w)`` into (0, 1). ``forward`` scales rows of a representation
    matrix by those squashed weights, so a torch optimiser can fit them.
    """

    def __init__(self, num_of_data_points, device):
        super().__init__()
        initial = torch.ones(num_of_data_points, device=device)
        # start every weight near 1.25 (sigmoid(2.5) ~= 0.92) with a little jitter
        torch.nn.init.normal_(initial, mean=1.25, std=0.1)
        self.weights = torch.nn.Parameter(initial)

    def calc_weights(self, start, stop):
        """Return sigmoid(2 * weights[start:stop]) as a (stop - start, 1) column."""
        raw = self.weights[start:stop]
        return torch.sigmoid(raw.unsqueeze(1) * 2)

    def forward(self, r, start, stop):
        """Scale each row of ``r`` by its squashed weight (the column broadcasts across features)."""
        column = self.calc_weights(start, stop)
        return r * column


def t_sne_plot(means, per_class_r):
    """Show a 2-D t-SNE projection of class means together with per-class representations.

    Args:
        means: dict mapping class label -> mean representation (tensor or array-like).
        per_class_r: dict mapping class label -> list of representations; it must contain
            every key of ``means``.

    Each class is drawn in the ``tab20`` colour indexed by its label, with its mean plotted
    alongside its points. Tensors may live on any device and may require grad; they are
    detached and copied to host memory before fitting.
    """
    def _to_numpy(v):
        # TSNE needs host arrays; a CUDA or grad-carrying tensor cannot be converted implicitly
        if isinstance(v, torch.Tensor):
            return v.detach().cpu().numpy()
        return np.asarray(v)

    classes = list(means.keys())
    data = [_to_numpy(means[y]) for y in classes]
    for y in classes:
        data.extend(_to_numpy(v) for v in per_class_r[y])
    tsne_proj = TSNE(2, verbose=1).fit_transform(np.stack(data))

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still available
    cmap = plt.get_cmap('tab20')
    _, ax = plt.subplots(figsize=(8, 8))
    # rows 0..len(classes)-1 are the means; each class's points follow in class order
    offset = len(classes)
    for i, y in enumerate(classes):
        num_points = len(per_class_r[y])
        indices = [i] + list(range(offset, offset + num_points))
        offset += num_points
        ax.scatter(tsne_proj[indices, 0], tsne_proj[indices, 1], c=np.array(cmap(y)).reshape(1, 4), label=y,
                   alpha=0.5)
    ax.legend(fontsize='large', markerscale=2)
    plt.show()


def get_first_elements_tuples(xs):
    """Return a new list holding ``item[0]`` for every ``item`` in ``xs``.

    Used to pull the stored inputs out of ``(x, task_id)`` memory entries.
    """
    return [entry[0] for entry in xs]
